# graph_utils/metrics.py
"""
graph_utils.metrics
-------------------

Centralized metric registry for all kNN, mid-shortcuts, and CTMC embedding code.

Defines:
- MetricSpec: captures per-metric code, backend aliases, pre/post transforms.
- resolve(): single point of truth for alias mapping and metric logic.

To add a new metric:
1. Implement a function _newmetric(), following the signature of _euclidean or _cosine.
2. Add the resulting MetricSpec to _REG (or call register_metric() at runtime), using a unique code.
"""

from dataclasses import dataclass, field 
import numpy as np
from typing import Dict, Tuple, Any, Callable

@dataclass(frozen=True)
class MetricSpec:
    name: str
    code: int
    backends: Dict[str, str | None]
    needs_pre: bool = False
    pre: Callable = field(default=lambda X, p: (X, {}))
    post: Callable = field(default=lambda d, b: d)
    mid_extras: Callable = field(default=lambda X, p: {})

# ---- built-ins ----
def _euclidean():
    """Build the MetricSpec for Euclidean distance.

    The spec carries code 0, needs no pre-transform, and installs a `post`
    hook that leaves faiss/hnswlib distances untouched and multiplies every
    other backend's distances by themselves.
    """
    def _as_squared(d, backend):
        # Per the original note, sklearn reports the square-rooted distance;
        # faiss/hnswlib "l2" output is presumably already squared, so it is
        # passed through as-is.
        if backend in ("faiss", "hnswlib"):
            return d
        return d * d

    aliases = {
        "faiss": "l2",
        "hnswlib": "l2",
        "sklearn": "euclidean",
        "annoy": "euclidean",
        "pynndescent": "euclidean",
    }
    return MetricSpec(
        name="euclidean",
        code=0,
        backends=aliases,
        needs_pre=False,
        post=_as_squared,
    )

def _cosine(eps=1e-12):
    """Build the MetricSpec for cosine distance.

    Args:
        eps: Lower bound applied to every row norm, so the reported norms
            are never smaller than this value.

    Returns:
        A MetricSpec with code 1 whose `pre` and `mid_extras` hooks both
        expose the floored row norms under the "norms" key.
    """
    def _floored_row_norms(X):
        # L2 norm along axis 1, stored as float32 and clamped from below by eps.
        row_norms = np.linalg.norm(X, axis=1).astype(np.float32)
        return np.maximum(row_norms, eps)

    def _pre(X, p):
        # Norms are taken from X as given, before the float32 view/cast.
        extras = {"norms": _floored_row_norms(X)}
        return X.astype(np.float32, copy=False), extras

    def _mid_extras(X, p):
        return {"norms": _floored_row_norms(X)}

    # "faiss" maps to None here; presumably there is no direct alias for that
    # backend -- TODO confirm how callers treat a None entry.
    aliases = {
        "faiss": None,
        "hnswlib": "cosine",
        "sklearn": "cosine",
        "annoy": "angular",
        "pynndescent": "cosine",
    }
    return MetricSpec(
        name="cosine",
        code=1,
        backends=aliases,
        needs_pre=False,
        pre=_pre,
        mid_extras=_mid_extras,
    )


def _mahalanobis():
    """Build the MetricSpec for Mahalanobis distance.

    The metric is reduced to plain L2: `pre` maps the data through a factor
    L of the metric matrix M (M = L @ L.T), so that squared Euclidean
    distances between transformed rows equal (x - y)^T M (x - y). Backend
    aliases are therefore the Euclidean ones.

    Returns:
        A MetricSpec with code 2, `needs_pre=True`, a whitening `pre` hook
        and the same `post` hook shape as the Euclidean spec.
    """
    def pre(X, p):
        """Whiten X with the matrix found under params key "M" or "VI".

        Raises:
            ValueError: if neither key holds a matrix, or the matrix is not
                a square 2-D array.
        """
        params = p or {}
        # Explicit None checks: `a or b` on an ndarray raises
        # "truth value of an array ... is ambiguous".
        M = params.get("M")
        if M is None:
            M = params.get("VI")
        if M is None:
            raise ValueError("mahalanobis requires metric_params['M'] or ['VI']")

        # Factor in float64 for accuracy; only the final factor is cast down.
        M = np.asarray(M, dtype=np.float64)
        if M.ndim != 2 or M.shape[0] != M.shape[1]:
            raise ValueError(
                f"mahalanobis matrix must be a square 2-D array, got shape {M.shape}"
            )

        # Robust "sqrt": Cholesky when M is positive definite, otherwise a
        # symmetric square root from the eigendecomposition with eigenvalues
        # clipped to a small positive floor.
        try:
            L = np.linalg.cholesky(M)
        except np.linalg.LinAlgError:
            w, V = np.linalg.eigh(M)
            w = np.clip(w, 1e-8, None)
            L = (V * np.sqrt(w)) @ V.T
        L = L.astype(np.float32)

        # With M = L @ L.T, (x - y)^T M (x - y) == ||(x - y) @ L||^2, so rows
        # are multiplied by L itself (not L.T, which would realize L.T @ L).
        X32 = np.asarray(X, dtype=np.float32)
        return (X32 @ L).astype(np.float32), {}

    def post(d, backend):  # L2 after whitening
        # Same convention as the Euclidean spec: faiss/hnswlib values pass
        # through, every other backend's distances are multiplied by themselves.
        return d if backend in ("faiss", "hnswlib") else d * d

    return MetricSpec(
        name="mahalanobis", code=2,
        backends={
            "faiss": "l2",
            "hnswlib": "l2",
            "sklearn": "euclidean",
            "annoy": "euclidean",
            "pynndescent": "euclidean",
        },
        needs_pre=True,
        pre=pre, post=post,
    )

# Registry of built-in metrics, keyed by each spec's own name
# ("euclidean", "cosine", "mahalanobis" -- in that order).
_REG = {
    builtin.name: builtin
    for builtin in (_euclidean(), _cosine(), _mahalanobis())
}

def register_metric(spec: MetricSpec) -> None:
    """Add `spec` to the registry, replacing any entry with the same name.

    The registry key is the lower-cased `spec.name`, so registration is
    case-insensitive with respect to the metric name.
    """
    registry_key = spec.name.lower()
    _REG[registry_key] = spec

    
def resolve(name: str | None, params: Dict[str, Any] | None) -> MetricSpec:
    """Look up the registered MetricSpec for `name`.

    Args:
        name: Metric name, matched case-insensitively. A falsy value
            (None or "") selects "euclidean".
        params: Accepted for interface compatibility; not used by the lookup.

    Returns:
        The MetricSpec stored in the registry under the lower-cased name.

    Raises:
        ValueError: if no metric is registered under that name.
    """
    requested = name if name else "euclidean"
    spec = _REG.get(requested.lower())
    if spec is not None:
        return spec
    raise ValueError(f"Unknown metric: {name!r}")
